{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import ray\n",
    "from ray import tune\n",
    "\n",
    "import torch # to remove later\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import models\n",
    "import networks\n",
    "from datasets import PreprocessedSpeechDataLoader, VaryingDataLoader\n",
    "from nupic.research.frameworks.pytorch.image_transforms import RandomNoise\n",
    "\n",
    "from torchsummary import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for this GSC dynamic-sparse experiment, gathered in one plain dict.\n",
    "config = {\n",
    "    \"device\": \"cuda\" if torch.cuda.device_count() > 0 else \"cpu\",\n",
    "    \"dataset_name\": \"PreprocessedGSC\",\n",
    "    \"data_dir\": \"~/nta/datasets/gsc\",\n",
    "    \"batch_size_train\": (4, 16),\n",
    "    \"batch_size_test\": 1000,\n",
    "\n",
    "    # ----- Network -----\n",
    "    # Alternative: sweep over model types with\n",
    "    # tune.grid_search([\"BaseModel\", \"SparseModel\", \"DSNNMixedHeb\", \"DSNNConvHeb\"])\n",
    "    \"model\": \"DSNNConvHeb\",\n",
    "    \"network\": \"gsc_conv_heb\",\n",
    "\n",
    "    # ----- Optimizer -----\n",
    "    \"optim_alg\": \"SGD\",\n",
    "    \"momentum\": 0,\n",
    "    \"learning_rate\": 0.01,\n",
    "    \"weight_decay\": 0.01,\n",
    "    \"lr_scheduler\": \"StepLR\",\n",
    "    \"lr_gamma\": 0.90,\n",
    "    # Alternative: tune.grid_search([True, False])\n",
    "    \"use_kwinners\": True,\n",
    "\n",
    "    # ----- Dynamic sparsity: fully connected layers -----\n",
    "    # Original note on epsilon: 0.1 in the 1600-1000 linear layer.\n",
    "    \"epsilon\": 184.61538 / 3,\n",
    "    \"sparse_linear_only\": True,\n",
    "    \"start_sparse\": 1,\n",
    "    \"end_sparse\": -1,  # original note: don't get last layer\n",
    "    \"weight_prune_perc\": 0.15,\n",
    "    \"hebbian_prune_perc\": 0.60,\n",
    "    \"pruning_es\": True,\n",
    "    \"pruning_es_patience\": 0,\n",
    "    \"pruning_es_window_size\": 5,\n",
    "    \"pruning_es_threshold\": 0.02,\n",
    "    \"pruning_interval\": 1,\n",
    "\n",
    "    # ----- Dynamic sparsity: convolutional layers -----\n",
    "    # The list-valued entries below each carry two items.\n",
    "    \"prune_methods\": [\"dynamic\", \"dynamic\"],\n",
    "    \"hebbian_prune_frac\": [0.99, 0.99],\n",
    "    \"magnitude_prune_frac\": [0.0, 0.0],\n",
    "    \"sparsity\": [0.98, 0.98],\n",
    "    \"update_nsteps\": [50, 50],\n",
    "    \"prune_dims\": (),\n",
    "\n",
    "    # ----- Additional validation -----\n",
    "    \"test_noise\": False,\n",
    "    \"noise_level\": 0.1,\n",
    "\n",
    "    # ----- Debugging -----\n",
    "    \"debug_weights\": True,\n",
    "    \"debug_sparse\": True,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['dynamic', 'dynamic']\n",
      "defining the arg\n",
      "defining the arg\n",
      "defining the arg\n",
      "defining the arg\n",
      "defining the arg\n",
      "defining the arg\n",
      "defining the arg\n",
      "defining the arg\n",
      "[{'hebbian_prune_frac': 0.99, 'sparsity': 0.98, 'prune_dims': (), 'update_nsteps': 50}, {'hebbian_prune_frac': 0.99, 'sparsity': 0.98, 'prune_dims': (), 'update_nsteps': 50}]\n",
      "[('classifier.0', Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1))), ('classifier.4', Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1)))]\n",
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1           [-1, 64, 28, 28]           1,664\n",
      "       BatchNorm2d-2           [-1, 64, 28, 28]               0\n",
      "         MaxPool2d-3           [-1, 64, 14, 14]               0\n",
      "        KWinners2d-4           [-1, 64, 14, 14]               0\n",
      "            Conv2d-5           [-1, 64, 10, 10]         102,464\n",
      "       BatchNorm2d-6           [-1, 64, 10, 10]               0\n",
      "         MaxPool2d-7             [-1, 64, 5, 5]               0\n",
      "        KWinners2d-8             [-1, 64, 5, 5]               0\n",
      "           Flatten-9                 [-1, 1600]               0\n",
      "           Linear-10                 [-1, 1000]       1,601,000\n",
      "      BatchNorm1d-11                 [-1, 1000]               0\n",
      "         KWinners-12                 [-1, 1000]               0\n",
      "           Linear-13                   [-1, 12]          12,012\n",
      "================================================================\n",
      "Total params: 1,717,140\n",
      "Trainable params: 1,717,140\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.00\n",
      "Forward/backward pass size (MB): 1.11\n",
      "Params size (MB): 6.55\n",
      "Estimated Total Size (MB): 7.67\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the GSC conv network from `config`, then print a per-layer\n",
    "# summary for an input of size (1, 32, 32).\n",
    "network = networks.gsc_conv_heb(config=config)\n",
    "summary(\n",
    "    network,\n",
    "    input_size=(1, 32, 32),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Despite the name, this collects every (name, module) pair in the network,\n",
    "# not only the convolutional layers.\n",
    "named_convs = list(network.named_modules())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('', GSCHeb(\n",
       "   (classifier): Sequential(\n",
       "     (0): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1))\n",
       "     (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
       "     (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "     (3): KWinners2d(channels=64, n=12544, percent_on=0.095, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)\n",
       "     (4): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))\n",
       "     (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
       "     (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "     (7): KWinners2d(channels=64, n=1600, percent_on=0.125, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)\n",
       "     (8): Flatten()\n",
       "     (9): Linear(in_features=1600, out_features=1000, bias=True)\n",
       "     (10): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
       "     (11): KWinners(n=1000, percent_on=0.1, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)\n",
       "     (12): Linear(in_features=1000, out_features=12, bias=True)\n",
       "   )\n",
       "   (classifier.0): DSConv2d(\n",
       "     1, 64, kernel_size=(5, 5), stride=(1, 1)\n",
       "     (grouped_conv): _NullConv(25, 25, kernel_size=(5, 5), stride=(1, 1), groups=25, bias=False)\n",
       "   )\n",
       "   (classifier.4): DSConv2d(\n",
       "     64, 64, kernel_size=(5, 5), stride=(1, 1)\n",
       "     (grouped_conv): _NullConv(102400, 1600, kernel_size=(5, 5), stride=(1, 1), groups=1600, bias=False)\n",
       "   )\n",
       " ))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# In the saved run, entry 0 is the root GSCHeb module, listed under the empty name ''.\n",
    "named_convs[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('classifier', Sequential(\n",
       "   (0): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1))\n",
       "   (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
       "   (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "   (3): KWinners2d(channels=64, n=12544, percent_on=0.095, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)\n",
       "   (4): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))\n",
       "   (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
       "   (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "   (7): KWinners2d(channels=64, n=1600, percent_on=0.125, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)\n",
       "   (8): Flatten()\n",
       "   (9): Linear(in_features=1600, out_features=1000, bias=True)\n",
       "   (10): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
       "   (11): KWinners(n=1000, percent_on=0.1, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)\n",
       "   (12): Linear(in_features=1000, out_features=12, bias=True)\n",
       " ))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# In the saved run, entry 1 is the `classifier` Sequential container (sub-layers 0-12).\n",
    "named_convs[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('classifier.2',\n",
       " MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# In the saved run, entry 4 is `classifier.2`, a MaxPool2d layer, so the list\n",
    "# is not limited to convolutions.\n",
    "named_convs[4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
